# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# This file is part of ByteQC.
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from byteqc.embyte.Framework.RDM1_building import make_RDM1_equi_pair_group
from byteqc.embyte.Setting import MIN_GROUP_SIZE, RDM_MEMORY_POOL_SIZE
import fnmatch
import shutil
from byteqc import lib
from functools import reduce
import traceback
from byteqc.embyte.Tools.fragment import Fragment_group
from byteqc.embyte.Tools.logger import Logger, Process_Record
import threading
from byteqc.cuobc.lib.int3c import VHFOpt3c
import gc
import json
from mpi4py import MPI
from byteqc.embyte.Framework.high_level_process import high_level_processing
from byteqc.embyte.Framework import low_level_process
import time
from byteqc.embyte.Localization import iao
from byteqc.embyte.Solver import GPU_CCSDSolver
import os
import sys
import cupyx
import multiprocessing
import pickle
import h5py
import numpy
import cupy
# Module-level side effect: reset CuPy's pinned (page-locked) host memory
# allocator. Per the CuPy documentation, passing None makes CuPy use the raw
# pinned-memory allocator, i.e. the pinned memory pool is disabled.
# NOTE(review): presumably done so pinned host buffers (e.g. the
# cupyx.zeros_pinned staging arrays used when reading t1/l1) are returned to
# the system instead of being cached by a pool — confirm intent.
cupy.cuda.set_pinned_memory_allocator(None)


class SIE_kernel:
    '''
    Basic framework for SIE.
    '''

    def __init__(self, logfile, HF_chk):
        '''
        Store the basic settings of the SIE workflow.

        Required arguments:

        logfile: path of the folder in which the SIE output is saved.
            ``simulate`` creates it if it does not exist and writes the
            process recorder and ``main_logger.log`` into it.
        HF_chk: the HF checkpoint file generated by pyscf, gpu4pyscf or
            cuobc/cupbc. Note that ``simulate`` overwrites this attribute
            with its ``mean_field`` argument when that argument is a path.

        The attributes below may be adjusted after construction and before
        calling ``simulate``.

        Settings that are suggested to be set:

        electronic_structure_solver: the high-level solver class; only
            MP2/CCSD/CCSD(T) are available for now.
            Default: GPU_CCSDSolver.
        electron_localization_method: the orbital localization method;
            iao is highly recommended. Default: iao.
        threshold: the BNO threshold. Accepted forms are a float, a list of
            floats, or a list of [occ, vir] float pairs; ``simulate``
            normalizes it to a list of [occ, vir] pairs. Default: 1e-8.
        aux_basis: the auxiliary basis used for ERI generation in the SIE
            workflow. It should be the same as in the HF calculation if the
            HF calculation used density fitting. Default: None.
        oei_file: path of the one-electron integrals dumped from
            pyscf.scf.get_hcore. Default: None.
        jk_file: path of the JK matrix dumped from pyscf.scf.get_veff().
            Default: None.
        RDM: if True, the global 1-RDM and the in-cluster 2-RDM are
            calculated. Default: False.
        in_situ_T: if True and the high-level solver is CCSD, the in-situ
            perturbative (T) correction is computed. Default: True.

        Settings to change only when required:

        cheat_th: per-cluster BNO thresholds, given as
            ``{cluster_id: threshold}`` with the same accepted forms as
            ``threshold``. ``simulate`` propagates an entry to every
            cluster that is equivalent to ``cluster_id``. Default: None.
        eri_file: if given, the ERIs are loaded from this file and used to
            compute the subspace/cluster ERIs. For OBC calculations it is
            suggested to leave it as None; for PBC calculations it must be
            given. Only CDERI is supported for now. Default: None.

        Settings that should normally be left at their defaults (they may
        be removed in the future): RDM_solved_list, chk_point,
        pair_group_size, t2_y_buffer_pool_size, local_orb_path.
        '''
        self.logfile = logfile
        self.HF_chk = HF_chk

        # --- Settings suggested to be set (see docstring) ---
        self.electronic_structure_solver = GPU_CCSDSolver
        self.electron_localization_method = iao
        self.threshold = 1e-8
        self.aux_basis = None
        self.oei_file = None
        self.jk_file = None
        self.RDM = False
        self.in_situ_T = True

        # --- Settings to change only when required (see docstring) ---
        self.cheat_th = None
        self.eri_file = None

        # --- Settings normally left at their defaults ---
        # Optional fragment indices; simulate uses them to build
        # RDM_solved_group instead of the full equivalence partition.
        self.RDM_solved_list = None
        # Forwarded to Process_Record; when True, simulate pickles
        # intermediate classes so an interrupted run can be resumed.
        self.chk_point = True
        # Group size passed to Fragment_group.gourp_pair in simulate.
        self.pair_group_size = 500
        # NOTE(review): not read in the visible part of this file;
        # presumably the buffer pool size for RDM-related t2/y
        # intermediates — confirm against its users.
        self.t2_y_buffer_pool_size = RDM_MEMORY_POOL_SIZE
        # Forwarded to low_level_process.low_level_info as local_orb_path.
        self.local_orb_path = None

    def simulate(self, molecule, mean_field, fragments):

        if 'CCSD' in self.electronic_structure_solver.__name__ and self.RDM:
            self.if_l2 = True
        else:
            self.if_l2 = False

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        size = comm.Get_size()

        comm2 = MPI.COMM_WORLD

        if type(fragments) is not list:
            raise ValueError

        orb_list = []
        equivalent_list = []
        for frag in fragments:
            orb_list.append(frag['fragment_orb_list'])
            equivalent_list.append(frag['equivalent_level'])

        if all(value is None for value in equivalent_list):
            equivalent_list = [i for i in range(len(orb_list))]

        self.orb_list = orb_list
        self.equivalent_list = equivalent_list
        self.fragments = fragments

        equi_part_tmp = []

        for ind_i, level in enumerate(self.equivalent_list):
            while True:
                try:
                    equi_part_tmp[level].append(ind_i)
                except IndexError:
                    equi_part_tmp.append([])
                else:
                    break

        if self.cheat_th is not None:
            for part_list in equi_part_tmp:
                for clu_ind in part_list:
                    if clu_ind in self.cheat_th.keys():
                        for clu_ind_t in part_list:
                            self.cheat_th[clu_ind_t] = self.cheat_th[clu_ind]
                        break

        if isinstance(self.threshold, float):
            self.threshold = [[self.threshold, self.threshold]]
        elif isinstance(self.threshold, list):
            if all(isinstance(element, float) for element in self.threshold):
                self.threshold = [[element, element] for element in self.threshold]
            elif all(isinstance(element, list) for element in self.threshold):
                for th_tmp in self.threshold:
                    assert len(th_tmp) == 2, \
                        f'The BNO threshold only support for occ and vir, but get {len(th_tmp)} thresholds!'
                    assert all(isinstance(element, float) for element in th_tmp), \
                        f'The BNO threshold only support to be float!'
        else:
            assert False, 'The BNO threshold must be set as float or list with float'

        if self.cheat_th is not None:
            for clu_ind in self.cheat_th.keys():
                val = self.cheat_th[clu_ind]
                if isinstance(val, float):
                    self.cheat_th[clu_ind] = [[val, val]]
                elif isinstance(val, list):
                    if all(isinstance(element, float) for element in val):
                        self.cheat_th[clu_ind] = [[element, element] for element in val]
                    elif all(isinstance(element, list) for element in val):
                        for th_tmp in val:
                            assert len(th_tmp) == 2, \
                                f'The BNO threshold only support for occ and vir, but get {len(th_tmp)} thresholds!'
                            assert all(isinstance(element, float) for element in th_tmp), \
                                f'The BNO threshold only support to be float!'
                else:
                    assert False, 'The BNO threshold must be set as float or list with float'

        if rank == 0:

            if not os.path.exists(self.logfile):
                os.mkdir(self.logfile)

            self.PR = Process_Record(
                self.logfile + '/recorder', self.chk_point)
            self.LG = Logger(self.logfile + '/main_logger.log')

            if self.cheat_th is not None:
                self.LG.logger.info(
                    '------------ The cheat threshold is %s' %
                    self.cheat_th)
                self.LG.logger.info(
                    '------------ The original threshold is %s' %
                    self.threshold)
            else:
                self.LG.logger.info(
                    f'--------- the threshold : {self.threshold}')

            self.lock = threading.Lock()

            self.LG.logger.info('orb_list: %s' % orb_list)
            self.LG.logger.info('eq_list: %s' % self.equivalent_list)
            self.LG.logger.info(
                '=== Run HF for full system and build low_level_info')

            if self.aux_basis is not None:
                self.LG.logger.info(
                    f'--- The aux basis used %s' %
                    self.aux_basis)

            try:
                if self.PR.recorder['ProgramDone']:
                    self.LG.logger.info(
                        f'This program is done! Please check file {self.PR.filename}')
            except BaseException:
                self.PR.recorder['ProgramDone'] = False
                self.PR.save()

            if self.PR.recorder['ProgramDone']:
                sys.exit()

            if self.PR.recorder['HF_chkfile']:
                self.LG.logger.info('Load chkfile for mean_field!')
                self.LG.logger.info(
                    'check file path: %s' %
                    self.PR.recorder['HF_chkfile'])
                from pyscf import scf
                mean_field = scf.RHF(molecule)
                mean_field.verbose = 4
                try:
                    mean_field.__dict__.update(scf.chkfile.load(
                        self.PR.recorder['HF_chkfile'], 'scf'))
                    self.LG.logger.info(
                        '=== Load Check Point: load mean_field for full system')
                except BaseException:
                    pass

            else:
                if type(mean_field) is str:
                    self.HF_chk = mean_field
                    self.PR.recorder['HF_chkfile'] = self.HF_chk
                    self.LG.logger.info('Load chkfile from a given path!')
                    self.LG.logger.info(
                        'check file path: %s' %
                        self.PR.recorder['HF_chkfile'])

                    if getattr(molecule, 'pbc_intor', None):
                        from pyscf.pbc import scf as pbcscf
                        mean_field = pbcscf.RHF(molecule, exxdiv=None)
                        mean_field.verbose = 4
                        mean_field.__dict__.update(pbcscf.chkfile.load(
                            self.PR.recorder['HF_chkfile'], 'scf'))
                    else:
                        from pyscf import scf
                        mean_field = scf.RHF(molecule)
                        mean_field.verbose = 4
                        mean_field.__dict__.update(scf.chkfile.load(
                            self.PR.recorder['HF_chkfile'], 'scf'))

                    self.PR.save()

                else:
                    assert False, 'Expect mean_field to be a string path of the converged HF ' \
                        + f'checkpoint file given by pyscf! But a {type(mean_field)} is given!'

            self.LG.logger.info(
                '=== mean_field checkpoint file has been loaded')

            self.LG.logger.info(
                '=== converged SCF energy = %s' %
                mean_field.e_tot)

            if self.PR.recorder['fragment_group']:
                self.LG.logger.info(
                    'Load fragment_group class from %s' %
                    self.PR.recorder['fragment_group'])
                fragment_group = self.PR.load_class(
                    self.PR.recorder['fragment_group'])
                self.fragment_group = fragment_group
            else:
                fragment_group = Fragment_group(molecule, fragments)
                fragment_group.build()
                self.fragment_group = fragment_group
                self.fragment_group.gourp_pair(group_size=self.pair_group_size)
                if self.PR.chk_point:
                    self.PR.recorder['fragment_group'] = os.path.join(
                        self.logfile, 'fragment_group.pkl')
                    self.LG.logger.info(
                        'save fragment_group class in %s' %
                        self.PR.recorder['fragment_group'])
                    self.PR.save_class(
                        fragment_group, self.PR.recorder['fragment_group'])
                    self.PR.save()

            if self.PR.recorder['low_level_info_class']:
                self.LG.logger.info(
                    'Load low_level_info class from %s' %
                    self.PR.recorder['low_level_info_class'])
                low_level_info = self.PR.load_class(
                    self.PR.recorder['low_level_info_class'])
                low_level_info.mol_full = molecule
                del low_level_info.mol_full.stdout
                self.low_level_info = low_level_info

            else:
                self.LG.logger.info('=== buld low_level_info class')
                low_level_info = low_level_process.low_level_info(molecule,
                                                                  mean_field,
                                                                  self.LG,
                                                                  aux_basis=self.aux_basis,
                                                                  jk_file=self.jk_file,
                                                                  with_eri=False if self.eri_file is None else True,
                                                                  oei=self.oei_file,
                                                                  local_orb_path=self.local_orb_path,
                                                                  )
                if hasattr(low_level_info.mol_full, 'stdout'):
                    del low_level_info.mol_full.stdout
                del mean_field
                self.low_level_info = low_level_info

                Fock_MO = reduce(
                    cupy.dot, (cupy.asarray(
                        low_level_info.LOMO.T), cupy.asarray(
                        low_level_info.fock_LO), cupy.asarray(
                        low_level_info.LOMO))).get()
                numpy.save(
                    os.path.join(
                        self.PR.filepath,
                        'Fock_MO.npy'),
                    Fock_MO)
                assert numpy.isclose(low_level_info.mol_full.nelectron % 2, 0)
                nocc_full = round(low_level_info.mol_full.nelectron // 2)
                if not numpy.isclose(low_level_info.mol_full.nao - nocc_full,
                                     low_level_info.LOMO.shape[1] - nocc_full):
                    self.LG.logger.info(
                        f'The MO has been cut from {low_level_info.mol_full.nao} to {low_level_info.LOMO.shape[1]}.'
                        f' It may related to the linear dependency of the basis.')
                nvir_full = low_level_info.LOMO.shape[1] - nocc_full
                self.PR.recorder['nocc_nvir_full'] = [
                    int(nocc_full), int(nvir_full)]
                self.PR.save()
                del Fock_MO

                if self.PR.chk_point:
                    self.PR.recorder['low_level_info_class'] = os.path.join(
                        self.logfile, 'low_level_info_class.pkl')
                    self.LG.logger.info(
                        'save low_level_info class in %s' %
                        self.PR.recorder['low_level_info_class'])

                    self.PR.save_class(
                        low_level_info, self.PR.recorder['low_level_info_class'])
                    self.PR.save()
            self.nocc_full = self.PR.recorder['nocc_nvir_full'][0]
            self.nvir_full = self.PR.recorder['nocc_nvir_full'][1]
            self.nao = self.nocc_full + self.nvir_full
            low_scf_energy = low_level_info.low_scf_energy
            j2c_file = low_level_info.j2c

            del low_level_info
            gc.collect()

            self.LG.logger.info(
                '=== HF and build low_level_info part finished')

            self.cluster_path = os.path.join(self.logfile, 'Cluster')
            if not os.path.exists(self.cluster_path):
                os.mkdir(self.cluster_path)

            high_level_frag = high_level_processing(
                orb_list,
                self.PR.recorder['low_level_info_class'],
                self.electronic_structure_solver,
                self.threshold,
                self.cluster_path)
            high_level_frag.RDM = self.RDM
            high_level_frag.fragments = self.fragments
            high_level_frag.equivalent_list = self.equivalent_list
            high_level_frag.cheat_th = self.cheat_th
            high_level_frag.eri = self.eri_file
            high_level_frag.in_situ_T = self.in_situ_T
            try:
                high_level_frag.logfile = self.cluster_path
            except BaseException:
                pass

            energy = low_scf_energy

            self.LG.logger.info('---- Start to bcast from main node')

            comm.bcast([orb_list,
                        self.PR.recorder['low_level_info_class'],
                        self.cluster_path],
                       root=0)
            self.LG.logger.info('---- bcast finished')

            if not self.PR.recorder['energy']:
                self.PR.recorder['energy'] = self.energy = energy

                if type(self.threshold) is list:
                    self.PR.recorder['frag_CE'] = self.frag_CE = [
                        0] * len(self.threshold)
                    self.PR.recorder['used_orb_num'] = self.used_orb_num = [
                        0] * len(self.threshold)
                    self.PR.recorder['energy'] = self.energy = [
                        energy] * len(self.threshold)
                else:
                    self.PR.recorder['frag_CE'] = self.frag_CE = [0]
                    self.PR.recorder['used_orb_num'] = self.used_orb_num = [0]
                    self.PR.recorder['energy'] = self.energy = [energy]
                self.PR.save()
            else:
                if self.PR.chk_point:
                    self.energy = self.PR.recorder['energy']
                    self.used_orb_num = self.PR.recorder['used_orb_num']
                    self.frag_CE = self.PR.recorder['frag_CE']

            equi_part = []

            for ind_i, level in enumerate(self.equivalent_list):
                while True:
                    try:
                        equi_part[level].append(ind_i)
                    except IndexError:
                        equi_part.append([])
                    else:
                        break

            if self.RDM_solved_list is not None:
                self.RDM_solved_group = []
                temp_equi = numpy.asarray(self.equivalent_list)[
                    self.RDM_solved_list].tolist()
                for ind_i, level in enumerate(temp_equi):
                    while True:
                        try:
                            self.RDM_solved_group[level].append(ind_i)
                        except IndexError:
                            self.RDM_solved_group.append([])
                        else:
                            break
            else:
                self.RDM_solved_group = equi_part

            for equi_frag in equi_part:
                if self.fragments[equi_frag[0]
                                  ]['equivalent_operator'] != ['main']:
                    raise KeyboardInterrupt(
                        'The 1st equivalent fragment must be the main fragment!')

                equiv_frag_op = [
                    self.fragments[tmp_ind]['equivalent_operator'] for tmp_ind in equi_frag]
                main_frag_count = len(
                    [index for index, value in enumerate(equiv_frag_op) if value == ['main']])
                if main_frag_count != 1:
                    raise KeyboardInterrupt(
                        'The equivalent group %s has %s main fragment which should be only 1!!' %
                        (equi_frag, main_frag_count))

            self.running_list = [0] * size
            self.equi_part = equi_part
            thread_list = []

            def wait_return(node_ind, conn, part_list):
                if node_ind == 0:
                    frag_corr_test, used_orb_num_temp_test = conn.recv()
                else:
                    frag_corr_test, used_orb_num_temp_test = comm.recv(
                        source=node_ind, tag=0)

                if frag_corr_test is False:
                    self.lock.acquire()
                    self.running_list[node_ind] = 0
                    self.LG.logger.info('+++ Cluster %s failed ! +++' %
                                        (self.equi_part[self.equi_part.index(part_list)]))
                    self.lock.release()
                else:
                    self.lock.acquire()
                    self.running_list[node_ind] = 0

                    self.used_orb_num = numpy.asarray(self.used_orb_num)
                    self.used_orb_num += used_orb_num_temp_test * \
                        len(part_list)
                    self.used_orb_num = self.used_orb_num.tolist()

                    self.frag_CE = numpy.asarray(
                        self.frag_CE, dtype=numpy.float64)
                    self.frag_CE += frag_corr_test * len(part_list)
                    self.frag_CE = self.frag_CE.tolist()

                    self.PR.recorder['used_orb_num'] = self.used_orb_num
                    self.PR.recorder['frag_CE'] = self.frag_CE
                    self.PR.recorder['Cluster'][self.equi_part.index(
                        part_list)] = True
                    self.PR.save()

                    self.lock.release()

            if not self.PR.recorder['Cluster']:
                self.PR.recorder['Cluster'] = [False] * len(equi_part)
                self.PR.save()
            else:
                if self.PR.recorder['Cluster'] == [True] * len(equi_part):
                    self.LG.logger.info(
                        '----This system has already be sovled. Please check:%s' %
                        self.PR.filename)
                    for node_ind_kill in range(1, size):
                        comm.send(False, dest=node_ind_kill, tag=node_ind_kill)
                        if not self.RDM:
                            print('send kill single')
                            comm2.send(
                                False, dest=node_ind_kill, tag=node_ind_kill)
                    if not self.RDM:
                        sys.exit()

            if False in self.PR.recorder['Cluster']:
                self.LG.logger.info('Building vhfopt for all node!')
                if high_level_frag.vhfopt is None and high_level_frag.eri is None:
                    vhfopt = VHFOpt3c(
                        self.low_level_info.mol_full,
                        self.low_level_info.auxmol,
                        'int2e')
                    vhfopt.build(
                        group_size=MIN_GROUP_SIZE,
                        aux_group_size=MIN_GROUP_SIZE)
                    try:
                        vhfopt.auxcoeff = vhfopt.auxcoeff.get()
                    except BaseException:
                        pass
                    high_level_frag.vhfopt = vhfopt

                for equi_ind, part_list in enumerate(equi_part):

                    if self.PR.recorder['Cluster'][equi_ind]:
                        self.LG.logger.info(
                            '--- The cluster %s has already be solved! ---' %
                            equi_part[equi_ind][0])
                        continue

                    while not (0 in self.running_list):
                        time.sleep(1)
                    node_ind = self.running_list.index(0)
                    self.LG.logger.info(
                        '--- the cluster %s has been distributed to node %s ' %
                        (part_list[0], node_ind))

                    if node_ind == 0:
                        conn_recv, conn_send = multiprocessing.Pipe()
                        p1 = threading.Thread(target=task, args=(
                            high_level_frag, part_list, conn_send, ))
                        p1.start()
                    else:
                        comm.send(part_list, dest=node_ind, tag=node_ind)
                        conn_recv = None

                    self.running_list[node_ind] = 1
                    t = threading.Thread(
                        target=wait_return, args=(
                            node_ind, conn_recv, part_list))
                    t.start()
                    thread_list.append(t)

                for t in thread_list:
                    t.join()

                for node_ind_kill in range(1, size):
                    comm.send(False, dest=node_ind_kill, tag=node_ind_kill)

                try:
                    del high_level_frag.vhfopt, vhfopt
                except BaseException:
                    pass

            self.used_orb_num = numpy.asarray(
                self.used_orb_num) / len(self.equivalent_list)
            self.used_orb_num = self.used_orb_num.tolist()

            if False in self.PR.recorder['Cluster']:
                for node_ind_kill in range(1, size):
                    comm2.send(False, dest=node_ind_kill, tag=node_ind_kill)
                    self.LG.logger.info(
                        '----Cluster calculation is not finished!')
                raise KeyboardInterrupt('Cluster calculation is not finished!')
                sys.exit()

            if 'CCSD' in self.electronic_structure_solver.__name__ and self.in_situ_T:
                e_ccsd_t = [0] * len(self.threshold)
                for equi_i_group in self.equi_part:
                    for clu_i_main in equi_i_group:
                        if self.fragments[clu_i_main]['equivalent_operator'] == [
                                'main']:
                            break

                    filepath_clu_i_json = self.PR.filepath + \
                        '/Cluster/Cluster_%s/cluster_recorder' % (clu_i_main)
                    with open(filepath_clu_i_json, 'r') as f:
                        data = json.load(f)
                        e_ccsd_t += numpy.asarray(
                            data['T_correction']) * len(equi_i_group)

            if self.RDM:

                if False in self.PR.recorder['Cluster']:
                    for node_ind_kill in range(1, size):
                        comm2.send(
                            False, dest=node_ind_kill, tag=node_ind_kill)
                        self.LG.logger.info(
                            '----Cluster calculation is not finished!')
                    raise KeyboardInterrupt(
                        'Cluster calculation is not finished!')
                    sys.exit()

                if 'CCSD' in self.electronic_structure_solver.__name__:

                    try:
                        self.PR.recorder['global_t1']
                    except BaseException:
                        self.PR.recorder['global_t1'] = None
                        self.PR.recorder['global_l1'] = None
                        self.PR.save()

                    if self.PR.recorder['global_t1'] is None:

                        self.LG.logger.info('---- Get global t1 and l1')

                        nocc_full, nvir_full = self.PR.recorder['nocc_nvir_full']

                        path_list_t1 = []
                        path_list_l1 = []

                        for th_ind_tmp in range(len(self.threshold)):
                            t1_full = cupy.zeros((nocc_full, nvir_full))
                            l1_full = cupy.zeros((nocc_full, nvir_full))
                            for equi_i_group in self.equi_part:
                                for clu_i_main in equi_i_group:
                                    if self.fragments[clu_i_main]['equivalent_operator'] == [
                                            'main']:
                                        break

                                try:
                                    threshold_list = self.cheat_th[clu_i_main]
                                except BaseException:
                                    threshold_list = self.threshold

                                th_tmp = threshold_list[th_ind_tmp]
                                th_tmp = 'occ_1e%.1f_vir_1e%.1f' % \
                                    (numpy.log10(th_tmp[0]), numpy.log10(th_tmp[1]))

                                filepath_clu_i_main = self.PR.filepath + \
                                    '/Cluster/Cluster_%s/th_%s' % (
                                        clu_i_main, th_tmp)

                                with h5py.File(filepath_clu_i_main + '/t1') as f:
                                    t1_tmp = cupyx.zeros_pinned(f['t1'].shape)
                                    f['t1'].read_direct(t1_tmp)
                                    t1_tmp = cupy.asarray(t1_tmp)
                                with h5py.File(filepath_clu_i_main + '/l1') as f:
                                    l1_tmp = cupyx.zeros_pinned(f['l1'].shape)
                                    f['l1'].read_direct(l1_tmp)
                                    l1_tmp = cupy.asarray(l1_tmp)

                                for clu_i in equi_i_group:
                                    filepath_clu_i = self.PR.filepath + \
                                        '/Cluster/Cluster_%s/th_%s' % (
                                            clu_i, th_tmp)
                                    MO_occ_FRAG = cupy.asarray(
                                        numpy.load(filepath_clu_i + '/MO_occ_FRAG.npy'))
                                    CLU_occ_FRAG = cupy.asarray(
                                        numpy.load(filepath_clu_i + '/CLU_occ_FRAG.npy'))
                                    MO_occ_CLU_occ_FRAG = MO_occ_FRAG @ CLU_occ_FRAG.T
                                    MO_vir_CLU_vir = cupy.asarray(
                                        numpy.load(filepath_clu_i + '/MO_vir_CLU_vir.npy'))

                                    t1_full += reduce(cupy.dot,
                                                      (MO_occ_CLU_occ_FRAG,
                                                       t1_tmp,
                                                       MO_vir_CLU_vir.T))
                                    l1_full += reduce(cupy.dot,
                                                      (MO_occ_CLU_occ_FRAG,
                                                       l1_tmp,
                                                       MO_vir_CLU_vir.T))

                            t1_path = self.PR.filepath + \
                                f'/t1_{th_ind_tmp}_global.npy'
                            l1_path = self.PR.filepath + \
                                f'/l1_{th_ind_tmp}_global.npy'

                            numpy.save(t1_path, t1_full.get())
                            numpy.save(l1_path, l1_full.get())

                            path_list_t1.append(t1_path)
                            path_list_l1.append(l1_path)

                        self.PR.recorder['global_t1'] = path_list_t1
                        self.PR.recorder['global_l1'] = path_list_l1
                        self.PR.save()

                        self.LG.logger.info(
                            '--- Global t1 and l1 has been saved')

                try:
                    self.PR.recorder['RDM']
                except BaseException:
                    self.PR.recorder['RDM'] = [
                        False] * len(self.fragment_group.group_pair_relation_ship)
                    self.PR.recorder['1-RDM_energy'] = [0] * \
                        len(self.threshold)
                    self.PR.save()

                def wait_RDM_return(node_ind, equi_part,
                                    equi_pair_group, equi_pair_group_ind, conn):
                    if node_ind == 0:
                        e_corr_rdm1_result, doo_list, dvv_list, dov_list = conn.recv()
                    else:
                        e_corr_rdm1_result, doo_list, dvv_list, dov_list = comm2.recv(
                            source=node_ind, tag=0)

                    if e_corr_rdm1_result is False:
                        self.lock.acquire()
                        self.running_list[node_ind] = 0
                        self.LG.logger.info(
                            '+++ 1-RDM calculation in node %s failed! +++' %
                            (node_ind))
                        self.lock.release()
                    else:
                        self.lock.acquire()
                        self.running_list[node_ind] = 0

                        for equi_pair_ind, equi_pair in enumerate(
                                equi_pair_group):
                            rows, cols = zip(*equi_pair)
                            for th_ind_tmp in range(len(e_corr_rdm1_result)):
                                self.corr_energy_pair[th_ind_tmp, rows,
                                                      cols] = e_corr_rdm1_result[th_ind_tmp][equi_pair_ind]

                        for th_ind_tmp in range(len(e_corr_rdm1_result)):
                            self.RDM1[th_ind_tmp, :self.nocc_full,
                                      :self.nocc_full] += doo_list[th_ind_tmp]
                            self.RDM1[th_ind_tmp, self.nocc_full:,
                                      self.nocc_full:] += dvv_list[th_ind_tmp]
                            if dov_list is not None:
                                self.RDM1[th_ind_tmp, :self.nocc_full,
                                          self.nocc_full:] += dov_list[th_ind_tmp]
                                self.RDM1[th_ind_tmp, self.nocc_full:,
                                          :self.nocc_full] += dov_list[th_ind_tmp].T

                        self.corr_energy_pair.flush()
                        self.RDM1.flush()

                        self.PR.recorder['RDM'][equi_pair_group_ind] = True

                        self.PR.save()

                        self.LG.logger.info(
                            'Fragment pair group index %s has been solved.' %
                            equi_pair_group_ind)

                        self.lock.release()

                self.rdm1_energy = numpy.asarray(
                    self.PR.recorder['1-RDM_energy'], dtype=numpy.float64)
                thread_list_rdm = []

                file_corr_energy = os.path.join(
                    self.LG.filepath, 'corr_energy_mat.npy')
                file_RDM1 = os.path.join(self.LG.filepath, 'RDM1.npy')
                if os.path.exists(file_corr_energy):
                    self.corr_energy_pair = numpy.memmap(file_corr_energy, shape=(len(self.threshold), len(
                        self.fragments), len(self.fragments)), dtype='float64', mode='r+')
                    self.RDM1 = numpy.memmap(file_RDM1, shape=(
                        len(self.threshold), self.nao, self.nao), dtype='float64', mode='r+')
                else:
                    self.corr_energy_pair = numpy.memmap(file_corr_energy, shape=(len(self.threshold), len(
                        self.fragments), len(self.fragments)), dtype='float64', mode='w+')
                    self.RDM1 = numpy.memmap(file_RDM1, shape=(
                        len(self.threshold), self.nao, self.nao), dtype='float64', mode='w+')

                t2_y_buffer_pool = cupyx.zeros_pinned(
                    (self.t2_y_buffer_pool_size, ), dtype='float64')
                for equi_pair_group_ind, equi_pair_group in enumerate(
                        self.fragment_group.group_pair_relation_ship):
                    if self.PR.recorder['RDM'][equi_pair_group_ind]:
                        self.LG.logger.info(
                            '=== The equi_pair_group_ind %s has already been solved!' %
                            (equi_pair_group_ind))
                        continue

                    while not (0 in self.running_list):
                        time.sleep(1)
                    node_ind = self.running_list.index(0)

                    if node_ind == 0:
                        params = [
                            self.PR,
                            self.LG,
                            self.fragments,
                            self.equi_part,
                            equi_pair_group,
                            self.threshold,
                            t2_y_buffer_pool,
                            self.if_l2]
                        conn_recv, conn_send = multiprocessing.Pipe()
                        p1_rdm = threading.Thread(
                            target=task_rdm, args=(
                                params, self.cheat_th, conn_send, ))
                        p1_rdm.start()
                    else:
                        params = [
                            self.PR,
                            self.LG,
                            self.fragments,
                            self.equi_part,
                            equi_pair_group,
                            self.threshold,
                            self.if_l2]
                        comm2.send(params, dest=node_ind, tag=node_ind)
                        conn_recv = None
                    self.LG.logger.info(
                        'fragment pair group %s with %s pairs 1-RDM has been computed at node %s' %
                        (equi_pair_group_ind, len(equi_pair_group), node_ind))
                    self.running_list[node_ind] = 1
                    t_rdm = threading.Thread(
                        target=wait_RDM_return,
                        args=(
                            node_ind,
                            equi_part,
                            equi_pair_group,
                            equi_pair_group_ind,
                            conn_recv))
                    t_rdm.start()
                    thread_list_rdm.append(t_rdm)

                for t in thread_list_rdm:
                    t.join()
                    # t.close()
                for node_ind_kill in range(1, size):
                    comm2.send(False, dest=node_ind_kill, tag=node_ind_kill)
                del t2_y_buffer_pool

                try:
                    self.PR.recorder['build_RDM']
                except BaseException:
                    self.PR.recorder['build_RDM'] = False
                    self.PR.save()

                if not self.PR.recorder['build_RDM']:

                    if self.if_l2:

                        self.LG.logger.info(
                            'Using global t1 and l1 to get dov!')
                        RDM1_list = numpy.zeros(self.RDM1.shape)
                        for th_ind_tmp in range(len(self.threshold)):

                            t1_full = cupy.asarray(
                                numpy.load(self.PR.recorder['global_t1'][th_ind_tmp]))
                            l1_full = cupy.asarray(
                                numpy.load(self.PR.recorder['global_l1'][th_ind_tmp]))
                            s_vir = slice(self.nocc_full, self.nao)
                            s_occ = slice(0, self.nocc_full)

                            doo = cupy.asarray(
                                self.RDM1[th_ind_tmp][s_occ, s_occ])
                            dvv = cupy.asarray(
                                self.RDM1[th_ind_tmp][s_vir, s_vir])

                            dov = t1_full + l1_full + \
                                cupy.asarray(
                                    self.RDM1[th_ind_tmp][s_occ, s_vir])
                            dov -= reduce(cupy.dot,
                                          (t1_full, l1_full.T, t1_full))
                            lib.contraction(
                                'im', doo, 'ma', t1_full, 'ia', dov, beta=1.0)
                            lib.contraction(
                                'ie', t1_full, 'ae', dvv, 'ia', dov, beta=1.0, alpha=-1.0)
                            lib.contraction(
                                'ja',
                                t1_full,
                                'ia',
                                l1_full,
                                'ij',
                                doo,
                                beta=1.0,
                                alpha=-1.0)
                            lib.contraction(
                                'ia', t1_full, 'ib', l1_full, 'ab', dvv, beta=1.0)

                            doo += doo.T
                            dvv += dvv.T
                            # dov *= 2

                            RDM1_list[th_ind_tmp][s_occ, s_vir] = dov.get()
                            RDM1_list[th_ind_tmp][s_vir, s_occ] = dov.get().T
                            RDM1_list[th_ind_tmp][s_occ, s_occ] = doo.get()
                            RDM1_list[th_ind_tmp][s_vir, s_vir] = dvv.get()

                        self.RDM1[:] = RDM1_list
                        self.RDM1.flush()
                        self.LG.logger.info('Done!')
                        self.PR.recorder['build_RDM'] = True
                        self.PR.save()

                        del t1_full, l1_full, doo, dvv, dov, RDM1_list
                    else:
                        s_vir = slice(self.nocc_full, self.nao)
                        s_occ = slice(0, self.nocc_full)
                        for th_ind_tmp in range(len(self.threshold)):
                            self.RDM1[th_ind_tmp][s_occ,
                                                  s_occ] += self.RDM1[th_ind_tmp][s_occ,
                                                                                  s_occ].T
                            self.RDM1[th_ind_tmp][s_vir,
                                                  s_vir] += self.RDM1[th_ind_tmp][s_vir,
                                                                                  s_vir].T

                        self.RDM1.flush()
                        self.LG.logger.info('Done!')
                        self.PR.recorder['build_RDM'] = True
                        self.PR.save()

                if False in self.PR.recorder['RDM']:
                    self.LG.logger.info(
                        "There is/are a/some cluster 1-RDM which not be solve! Check main-recoder!")
                    sys.exit()

                for equi_frag_group in self.RDM_solved_group:
                    corr_cluster = self.corr_energy_pair[:, :, equi_frag_group].sum(
                        axis=(1, 2))
                    if self.cheat_th is not None:
                        if equi_frag_group[0] in list(self.cheat_th.keys()):
                            if type(
                                    self.cheat_th[equi_frag_group[0]][0]) is int:
                                self.LG.logger.info('threshold : %s, Correlation energy '
                                                    'contribution of fragment %s : %s' % (
                                                        numpy.asarray(self.cheat_th[equi_frag_group[0]]),
                                                        equi_frag_group, corr_cluster.tolist()))
                            else:
                                self.LG.logger.info('threshold 1ex : %s, Correlation energy '
                                                    'contribution of fragment %s : %s' % (
                                                        numpy.log10(numpy.asarray(self.cheat_th[equi_frag_group[0]])),
                                                        equi_frag_group, corr_cluster.tolist()))
                        else:
                            if type(self.threshold[0]) is int:
                                self.LG.logger.info(
                                    "threshold : %s, Correlation energy contribution of fragment %s : %s" %
                                    (numpy.asarray(
                                        self.threshold),
                                        equi_frag_group,
                                        corr_cluster.tolist()))
                            else:
                                self.LG.logger.info(
                                    "threshold 1ex : %s, Correlation energy contribution of fragment %s : %s" %
                                    (numpy.log10(
                                        numpy.asarray(
                                            self.threshold)).tolist(),
                                        equi_frag_group,
                                        corr_cluster.tolist()))
                    else:
                        if type(self.threshold[0]) is int:
                            self.LG.logger.info(
                                "threshold : %s, Correlation energy contribution of fragment %s : %s" %
                                (numpy.asarray(
                                    self.threshold),
                                    equi_frag_group,
                                    corr_cluster.tolist()))
                        else:
                            self.LG.logger.info(
                                "threshold 1ex : %s, Correlation energy contribution of fragment %s : %s" %
                                (numpy.log10(
                                    numpy.asarray(
                                        self.threshold)).tolist(),
                                    equi_frag_group,
                                    corr_cluster.tolist()))
                    self.rdm1_energy += corr_cluster
                self.PR.recorder['1-RDM_energy'] = self.rdm1_energy.tolist()
                self.PR.save()

                Fock_MO = cupy.asarray(
                    numpy.load(
                        os.path.join(
                            self.PR.filepath,
                            'Fock_MO.npy')))
                RDM1_list = cupy.asarray(self.RDM1)
                RDM1_e_corr = lib.contraction(
                    'Kij', RDM1_list, 'ij', Fock_MO, 'K').get()

                self.LG.logger.info(
                    f'The 1-RDM correlation energy calculate on-the-fly: {self.rdm1_energy.tolist()}')
                self.LG.logger.info(
                    f'The 1-RDM correlation energy calculate from 1-RDM: {RDM1_e_corr.tolist()}')
                if not self.if_l2:
                    self.LG.logger.info('Above 2 energy should be the same!')

                self.rdm1_energy = RDM1_e_corr

                rdm2_energy = 0
                for equi_i_group in self.equi_part:

                    for clu_i in equi_i_group:
                        if self.fragments[clu_i]['equivalent_operator'] == [
                                'main']:
                            break

                    with open(self.PR.filepath + '/Cluster/Cluster_%s/cluster_recorder' % clu_i, 'r') as jsonfile:
                        try:
                            cumulant_energy = numpy.asarray(
                                json.load(jsonfile)['cumulant_energy']).copy()
                            if len(cumulant_energy) != len(self.threshold):
                                self.LG.logger.info(
                                    'The len(cumulant_energy) and len(threshold) does '
                                    'not match which are %s, %s in group %s' %
                                    (len(cumulant_energy), len(
                                        self.threshold), equi_i_group))
                                sys.exit()

                            for tmp_i in range(len(cumulant_energy) - 1):
                                diff_tmp = abs(
                                    cumulant_energy[tmp_i + 1] - cumulant_energy[tmp_i])
                                if diff_tmp > 1:
                                    self.LG.logger.info(
                                        'The energy difference between threshold is larger than '
                                        '1 in group %s which does not make sense in normal.' %
                                        (equi_i_group))

                            rdm2_energy += cumulant_energy * len(equi_i_group)
                        except Exception as e:
                            self.LG.logger.info(
                                'Get cumulant_energy failed in fragment group %s' %
                                (equi_i_group))
                            self.LG.logger.info(e)
                            self.LG.logger.info(traceback.format_exc())
                            sys.exit()

                self.LG.logger.info(
                    '-------------------------------------------------')
                self.LG.logger.info(
                    'Correlation energy from 1-RDM: %s' %
                    self.rdm1_energy.tolist())
                self.LG.logger.info(
                    'Correlation energy from 2-RDM: %s' %
                    rdm2_energy.tolist())
                self.LG.logger.info(
                    'Correlation energy from RDM: %s' %
                    (rdm2_energy + self.rdm1_energy).tolist())
                self.LG.logger.info(
                    'Correlation energy from CI-coefficients: %s' %
                    self.frag_CE)

                if 'CCSD' in self.electronic_structure_solver.__name__ and self.in_situ_T:
                    self.LG.logger.info('====================================')
                    self.LG.logger.info(f'CCSD(T) correction is {e_ccsd_t}')
                    self.LG.logger.info(
                        'Correlation energy from RDM with (T) correction: %s' %
                        (rdm2_energy + self.rdm1_energy + e_ccsd_t).tolist())
                    self.LG.logger.info(
                        'Correlation energy from CI-coefficients with (T) correction: %s' %
                        (numpy.asarray(self.frag_CE) + numpy.asarray(e_ccsd_t)).tolist())
                    self.LG.logger.info('====================================')

                self.LG.logger.info(
                    '-------------------------------------------------')
                self.energy = numpy.asarray(
                    self.energy) + self.rdm1_energy + rdm2_energy

                if 'CCSD' in self.electronic_structure_solver.__name__ and self.in_situ_T:
                    self.energy += e_ccsd_t
                self.LG.logger.info(
                    'Total RDM energy: %s' %
                    (self.energy.tolist()))

            else:
                self.LG.logger.info(
                    '-------------------------------------------------')
                self.LG.logger.info(
                    'Correlation energy from CI-coefficients: %s' %
                    numpy.asarray(self.frag_CE).tolist())
                if 'CCSD' in self.electronic_structure_solver.__name__ and self.in_situ_T:
                    self.LG.logger.info('====================================')
                    self.LG.logger.info(f'CCSD(T) correction is {e_ccsd_t}')
                    self.LG.logger.info(
                        'Correlation energy from CI-coefficients with (T) correction: %s' %
                        (numpy.asarray(self.frag_CE) + numpy.asarray(e_ccsd_t)).tolist())
                    self.LG.logger.info('====================================')

                self.energy = numpy.asarray(self.energy) + self.frag_CE
                self.LG.logger.info(
                    '-------------------------------------------------')
                if 'CCSD' in self.electronic_structure_solver.__name__ and self.in_situ_T:
                    self.energy += e_ccsd_t
                self.LG.logger.info(
                    'Total energy: %s' %
                    (self.energy.tolist()))

            self.PR.recorder['energy'] = self.energy.tolist()
            self.PR.save()

            self.LG.logger.info(
                '-------------- All process finished. Start to clean up!')
            if self.RDM:
                if False not in self.PR.recorder['RDM']:
                    def delete_folders(path, parent_pattern, child_pattern):
                        for root, dirs, _ in os.walk(path):
                            for dir in fnmatch.filter(dirs, parent_pattern):
                                cluster_folder = os.path.join(root, dir)
                                for subroot, subdirs, _ in os.walk(
                                        cluster_folder):
                                    for subdir in fnmatch.filter(
                                            subdirs, child_pattern):
                                        shutil.rmtree(
                                            os.path.join(subroot, subdir))
                                        pass

                    delete_folders(
                        self.PR.filepath + 'Cluster/', 'Cluster_*', 'th*')

            if self.PR.recorder['Cluster'] == [True] * len(equi_part):
                try:
                    os.remove(j2c_file)
                except BaseException:
                    pass

                try:
                    os.remove(self.PR.recorder['low_level_info_class'])
                except BaseException:
                    pass

                def delete_files(path, parent_pattern, child_pattern):
                    for root, dirs, _ in os.walk(path):
                        for dir in fnmatch.filter(dirs, parent_pattern):
                            cluster_folder = os.path.join(root, dir)
                            try:
                                os.remove(
                                    os.path.join(
                                        cluster_folder,
                                        child_pattern))
                            except BaseException:
                                pass

                delete_files(
                    self.PR.filepath
                    + 'Cluster/',
                    'Cluster_*',
                    'bath_orb_new')

            self.PR.recorder['ProgramDone'] = True
            self.PR.save()

            return self.energy, self.used_orb_num

        else:
            orb_list, low_level_info_class_add, cluster_path = comm.bcast(
                None, root=0)
            high_level_frag = high_level_processing(
                orb_list,
                low_level_info_class_add,
                self.electronic_structure_solver,
                self.threshold,
                cluster_path)
            high_level_frag.RDM = self.RDM
            high_level_frag.fragments = self.fragments
            high_level_frag.equivalent_list = self.equivalent_list
            high_level_frag.cheat_th = self.cheat_th
            high_level_frag.eri = self.eri_file
            high_level_frag.in_situ_T = self.in_situ_T
            try:
                high_level_frag.logfile = cluster_path
            except BaseException:
                pass

            while True:
                part_list = comm.recv(source=0, tag=rank)
                if part_list:
                    if high_level_frag.vhfopt is None and high_level_frag.eri is None:
                        f = open(low_level_info_class_add, 'rb')
                        low_level_info = pickle.loads(f.read())
                        f.close()
                        vhfopt = VHFOpt3c(
                            low_level_info.mol_full, low_level_info.auxmol, 'int2e')
                        vhfopt.build(
                            group_size=MIN_GROUP_SIZE,
                            aux_group_size=MIN_GROUP_SIZE)
                        try:
                            vhfopt.auxcoeff = vhfopt.auxcoeff.get()
                        except BaseException:
                            pass

                        try:
                            vhfopt.coeff = vhfopt.auxcoeff.get()
                        except BaseException:
                            pass

                        high_level_frag.vhfopt = vhfopt

                        del low_level_info

                    try:
                        frag_corr_test, used_orb_num_temp_test = high_level_frag.kernel(
                            part_list[0])
                        comm.send(
                            [frag_corr_test, used_orb_num_temp_test], dest=0, tag=0)
                    except Exception as e:
                        high_level_frag.LG.logger.info(e)
                        high_level_frag.LG.logger.info(traceback.format_exc())
                        comm.send([False, False], dest=0, tag=0)
                else:
                    print('Node %s break the solver loop!' % (rank))
                    break
            try:
                del high_level_frag.vhfopt, vhfopt
            except BaseException:
                pass
            print('Node %s start the 1-RDM loop!' % (rank))
            t2_y_buffer_pool = cupyx.zeros_pinned(
                (self.t2_y_buffer_pool_size, ), dtype='float64')
            while True:
                params = comm2.recv(source=0, tag=rank)
                print('Node %s recv the params from node 0!' % rank)
                if params is False:
                    print('Node %s break in 1-RDM loop!' % rank)
                    break
                else:
                    PR, LG, fragments, equi_part, equi_pair_group, threshold_list, if_l2 = params

                    corr_energy_list, doo_list, dvv_list, dov_list = make_RDM1_equi_pair_group(
                        PR, LG, fragments, equi_part, equi_pair_group,
                        threshold_list, cheat_th=self.cheat_th,
                        t2_y_buffer_pool=t2_y_buffer_pool, if_l2=if_l2)
                    comm2.send([corr_energy_list, doo_list,
                               dvv_list, dov_list], dest=0, tag=0)

            del t2_y_buffer_pool
            sys.exit()


def task(high_level_frag, part_list, conn):
    '''
    Run the high-level solver on one fragment part and report through a pipe.

    high_level_frag: object exposing ``kernel(part)`` and ``LG.logger``.
    part_list: sequence whose first element is handed to ``kernel``.
    conn: sending end of a pipe/connection.

    On success ``[frag_corr, used_orb_num]`` (the two values returned by
    ``kernel``) is sent. If anything raises, the exception text and its
    traceback are logged through ``high_level_frag.LG.logger`` and the
    sentinel ``[False, False]`` is sent instead, so the receiver is told
    about the failure rather than left without a reply.
    '''
    try:
        corr, orb_num = high_level_frag.kernel(part_list[0])
        conn.send([corr, orb_num])
    except Exception as err:
        log = high_level_frag.LG.logger
        log.info(err)
        log.info(traceback.format_exc())
        conn.send([False, False])


def task_rdm(params, cheat_th, conn):
    '''
    Compute the 1-RDM pieces of one equivalent fragment-pair group and send
    them back through ``conn``.

    params: list unpacked as
        [PR, LG, fragments, equi_part, equi_pair_group, threshold_list,
         t2_y_buffer_pool, if_l2]
    cheat_th: forwarded unchanged to ``make_RDM1_equi_pair_group``.
    conn: sending end of a pipe.

    On success ``[corr_energy_list, doo_list, dvv_list, dov_list]`` (the four
    values returned by ``make_RDM1_equi_pair_group``) is sent. If anything
    raises, the exception and its traceback are logged through ``LG.logger``
    and ``[False, False, False, False]`` is sent instead. The receiver blocks
    on ``conn.recv()`` and only marks the node as free again after it gets a
    reply, treating a leading ``False`` as a failed group. Without this
    reply, a crash here would leave the receiver waiting indefinitely.
    '''
    PR, LG, fragments, equi_part, equi_pair_group, threshold_list, \
        t2_y_buffer_pool, if_l2 = params

    try:
        corr_energy_list, doo_list, dvv_list, dov_list = \
            make_RDM1_equi_pair_group(
                PR, LG, fragments, equi_part, equi_pair_group,
                threshold_list, cheat_th=cheat_th,
                t2_y_buffer_pool=t2_y_buffer_pool, if_l2=if_l2)
        conn.send([corr_energy_list, doo_list, dvv_list, dov_list])
    except Exception as e:
        LG.logger.info(e)
        LG.logger.info(traceback.format_exc())
        # The receiver unpacks exactly four values, so the failure sentinel
        # keeps that arity.
        conn.send([False, False, False, False])
